import json
import os
from copy import deepcopy

import numpy as np
# import torch
# import torchvision.transforms as T
# from FlagEmbedding import BGEM3FlagModel
from marker.config.parser import ConfigParser
from marker.converters.pdf import PdfConverter
from marker.output import text_from_rendered
from PIL import Image
# from torchvision.transforms.functional import InterpolationMode
# from transformers import AutoFeatureExtractor, AutoModel

# from utils.src.presentation import Presentation, SlidePage
# from utils.src.utils import is_image_path, pjoin
# Shorthand for os.path.join; stands in for the `pjoin` helper that used to be
# imported from utils.src.utils (that import is commented out above).
from os.path import join as pjoin

# device_count = torch.cuda.device_count()


# def prs_dedup(
#     presentation: Presentation,
#     model: BGEM3FlagModel,
#     batchsize: int = 32,
#     threshold: float = 0.8,
# ) -> list[SlidePage]:
#     """
#     Deduplicate slides in a presentation based on text similarity.

#     Args:
#         presentation (Presentation): The presentation object containing slides.
#         model: The model used for generating text embeddings.
#         batchsize (int): The batch size for processing slides.
#         threshold (float): The similarity threshold for deduplication.

#     Returns:
#         list: A list of removed duplicate slides.
#     """
#     text_embeddings = get_text_embedding(
#         [i.to_text() for i in presentation.slides], model, batchsize
#     )
#     pre_embedding = text_embeddings[0]
#     slide_idx = 1
#     duplicates = []
#     while slide_idx < len(presentation):
#         cur_embedding = text_embeddings[slide_idx]
#         if torch.cosine_similarity(pre_embedding, cur_embedding, -1) > threshold:
#             duplicates.append(slide_idx - 1)
#         slide_idx += 1
#         pre_embedding = cur_embedding
#     return [presentation.slides.pop(i) for i in reversed(duplicates)]


# def get_text_model(device: str = None) -> BGEM3FlagModel:
#     """
#     Initialize and return a text model.

#     Args:
#         device (str): The device to run the model on.

#     Returns:
#         BGEM3FlagModel: The initialized text model.
#     """
#     return BGEM3FlagModel(
#         "BAAI/bge-m3",
#         use_fp16=True,
#         device=device,
#     )


# def get_image_model(device: str = None):
#     """
#     Initialize and return an image model and its feature extractor.

#     Args:
#         device (str): The device to run the model on.

#     Returns:
#         tuple: A tuple containing the feature extractor and the image model.
#     """
#     model_base = "google/vit-base-patch16-224-in21k"
#     return (
#         AutoFeatureExtractor.from_pretrained(
#             model_base,
#             torch_dtype=torch.float16,
#             device_map=device,
#         ),
#         AutoModel.from_pretrained(
#             model_base,
#             torch_dtype=torch.float16,
#             device_map=device,
#         ).eval(),
#     )


def parse_pdf(
    pdf_path: str,
    output_path: str = None,
    model_lst: dict = None,
    save_file: bool = True,
) -> "str | tuple[str, object]":
    """
    Parse a PDF file with marker and extract its Markdown text and images.

    Args:
        pdf_path (str): The path to the PDF file.
        output_path (str): Directory that receives the extracted content
            (``source.md``, the extracted images and ``meta.json``). Created
            if missing. Required when ``save_file`` is True; ignored otherwise.
        model_lst (dict): Pre-loaded marker models, passed unchanged to
            ``PdfConverter`` as its ``artifact_dict``.
        save_file (bool): Whether to write the extracted content to
            ``output_path``.

    Returns:
        str: The full Markdown text, when ``save_file`` is True.
        tuple: ``(full_text, rendered)`` when ``save_file`` is False, where
            ``rendered`` is the object returned by the marker converter.

    Raises:
        ValueError: If ``save_file`` is True but ``output_path`` is None.
    """
    if save_file:
        # Fail before the (expensive) conversion instead of letting
        # os.makedirs(None) raise an opaque TypeError.
        if output_path is None:
            raise ValueError("output_path is required when save_file=True")
        os.makedirs(output_path, exist_ok=True)

    config_parser = ConfigParser({"output_format": "markdown"})
    converter = PdfConverter(
        config=config_parser.generate_config_dict(),
        artifact_dict=model_lst,
        processor_list=config_parser.get_processors(),
        renderer=config_parser.get_renderer(),
    )
    rendered = converter(pdf_path)
    full_text, _, images = text_from_rendered(rendered)

    if not save_file:
        return full_text, rendered

    with open(pjoin(output_path, "source.md"), "w", encoding="utf-8") as f:
        f.write(full_text)
    for filename, image in images.items():
        # JPEG cannot store alpha/palette modes (RGBA, P, LA, ...); PIL raises
        # OSError for them, so normalise to RGB first (as marker itself does).
        if image.mode != "RGB":
            image = image.convert("RGB")
        image.save(pjoin(output_path, filename), "JPEG")
    with open(pjoin(output_path, "meta.json"), "w", encoding="utf-8") as f:
        f.write(json.dumps(rendered.metadata, indent=4))
    return full_text


# def get_text_embedding(
#     text: list[str], model: BGEM3FlagModel, batchsize: int = 32
# ) -> list[torch.Tensor]:
#     """
#     Generate text embeddings for a list of text strings.

#     Args:
#         text (list[str]): A list of text strings.
#         model: The model used for generating embeddings.
#         batchsize (int): The batch size for processing text.

#     Returns:
#         list: A list of text embeddings.
#     """
#     if isinstance(text, str):
#         return torch.tensor(model.encode(text)["dense_vecs"]).to(model.device)
#     result = []
#     for i in range(0, len(text), batchsize):
#         result.extend(
#             torch.tensor(model.encode(text[i : i + batchsize])["dense_vecs"]).to(
#                 model.device
#             )
#         )
#     return result


# def get_image_embedding(
#     image_dir: str, extractor, model, batchsize: int = 16
# ) -> dict[str, torch.Tensor]:
#     """
#     Generate image embeddings for images in a directory.

#     Args:
#         image_dir (str): The directory containing images.
#         extractor: The feature extractor for images.
#         model: The model used for generating embeddings.
#         batchsize (int): The batch size for processing images.

#     Returns:
#         dict: A dictionary mapping image filenames to their embeddings.
#     """
#     transform = T.Compose(
#         [
#             T.Resize(int((256 / 224) * extractor.size["height"])),
#             T.CenterCrop(extractor.size["height"]),
#             T.ToTensor(),
#             T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
#         ]
#     )

#     inputs = []
#     embeddings = []
#     images = [i for i in sorted(os.listdir(image_dir)) if is_image_path(i)]
#     for file in images:
#         image = Image.open(pjoin(image_dir, file)).convert("RGB")
#         inputs.append(transform(image))
#         if len(inputs) % batchsize == 0 or file == images[-1]:
#             batch = {"pixel_values": torch.stack(inputs).to(model.device)}
#             embeddings.extend(model(**batch).last_hidden_state.detach())
#             inputs.clear()
#     return {image: embedding.flatten() for image, embedding in zip(images, embeddings)}


# def images_cosine_similarity(embeddings: list[torch.Tensor]) -> torch.Tensor:
#     """
#     Calculate the cosine similarity matrix for a list of embeddings.
#     Args:
#         embeddings (list[torch.Tensor]): A list of image embeddings.

#     Returns:
#         torch.Tensor: A NxN similarity matrix.
#     """
#     embeddings = [embedding for embedding in embeddings]
#     sim_matrix = torch.zeros((len(embeddings), len(embeddings)))
#     for i in range(len(embeddings)):
#         for j in range(i + 1, len(embeddings)):
#             sim_matrix[i, j] = sim_matrix[j, i] = torch.cosine_similarity(
#                 embeddings[i], embeddings[j], -1
#             )
#     return sim_matrix


# Conventional per-channel ImageNet normalisation statistics (mean and
# standard deviation, one value per colour channel), as typically fed to
# torchvision-style ``Normalize`` transforms. Not referenced by the active
# code in this part of the file.
IMAGENET_MEAN = (
    0.485,
    0.456,
    0.406,
)
IMAGENET_STD = (
    0.229,
    0.224,
    0.225,
)


# def average_distance(
#     similarity: torch.Tensor, idx: int, cluster_idx: list[int]
# ) -> float:
#     """
#     Calculate the average distance between a point (idx) and a cluster (cluster_idx).

#     Args:
#         similarity (np.ndarray): The similarity matrix.
#         idx (int): The index of the point.
#         cluster_idx (list): The indices of the cluster.

#     Returns:
#         float: The average distance.
#     """
#     if idx in cluster_idx:
#         return 0
#     total_similarity = 0
#     for idx_in_cluster in cluster_idx:
#         total_similarity += similarity[idx, idx_in_cluster]
#     return total_similarity / len(cluster_idx)


# def get_cluster(similarity: np.ndarray, sim_bound: float = 0.65):
#     """
#     Cluster points based on similarity.

#     Args:
#         similarity (np.ndarray): The similarity matrix.
#         sim_bound (float): The similarity threshold for clustering.

#     Returns:
#         list: A list of clusters.
#     """
#     num_points = similarity.shape[0]
#     clusters = []
#     sim_copy = deepcopy(similarity)
#     added = [False] * num_points
#     while True:
#         max_avg_dist = sim_bound
#         best_cluster = None
#         best_point = None

#         for c in clusters:
#             for point_idx in range(num_points):
#                 if added[point_idx]:
#                     continue
#                 avg_dist = average_distance(sim_copy, point_idx, c)
#                 if avg_dist > max_avg_dist:
#                     max_avg_dist = avg_dist
#                     best_cluster = c
#                     best_point = point_idx

#         if best_point is not None:
#             best_cluster.append(best_point)
#             added[best_point] = True
#             similarity[best_point, :] = 0
#             similarity[:, best_point] = 0
#         else:
#             if similarity.max() < sim_bound:
#                 break
#             i, j = np.unravel_index(np.argmax(similarity), similarity.shape)
#             clusters.append([int(i), int(j)])
#             added[i] = True
#             added[j] = True
#             similarity[i, :] = 0
#             similarity[:, i] = 0
#             similarity[j, :] = 0
#             similarity[:, j] = 0
#     return clusters
